{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|\n",
      "|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|\n",
      "|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|\n",
      "+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql import SparkSession\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "spark = SparkSession.builder.appName('titanic_logreg').getOrCreate()\n",
    "df = spark.read.csv('titanic.csv', inferSchema = True, header = True)\n",
    "df.show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- PassengerId: integer (nullable = true)\n",
      " |-- Survived: integer (nullable = true)\n",
      " |-- Pclass: integer (nullable = true)\n",
      " |-- Name: string (nullable = true)\n",
      " |-- Sex: string (nullable = true)\n",
      " |-- Age: double (nullable = true)\n",
      " |-- SibSp: integer (nullable = true)\n",
      " |-- Parch: integer (nullable = true)\n",
      " |-- Ticket: string (nullable = true)\n",
      " |-- Fare: double (nullable = true)\n",
      " |-- Cabin: string (nullable = true)\n",
      " |-- Embarked: string (nullable = true)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "df.printSchema()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['PassengerId',\n",
       " 'Survived',\n",
       " 'Pclass',\n",
       " 'Name',\n",
       " 'Sex',\n",
       " 'Age',\n",
       " 'SibSp',\n",
       " 'Parch',\n",
       " 'Ticket',\n",
       " 'Fare',\n",
       " 'Cabin',\n",
       " 'Embarked']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_col = df.select(['Survived','Pclass','Sex','Age','SibSp','Parch','Fare','Embarked'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data = my_col.na.drop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import (VectorAssembler, StringIndexer, VectorIndexer, OneHotEncoder)\n",
    "\n",
    "gender_indexer = StringIndexer(inputCol = 'Sex', outputCol = 'SexIndex')\n",
    "gender_encoder = OneHotEncoder(inputCol='SexIndex', outputCol = 'SexVec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "embark_indexer = StringIndexer(inputCol = 'Embarked', outputCol = 'EmbarkIndex')\n",
    "embark_encoder = OneHotEncoder(inputCol = 'EmbarkIndex', outputCol = 'EmbarkVec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "assembler = VectorAssembler(inputCols = ['Pclass', 'SexVec', 'Age', 'SibSp', 'Parch', 'Fare', 'EmbarkVec'], outputCol = 'features')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "\n",
    "log_reg = LogisticRegression(featuresCol = 'features', labelCol = 'Survived')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = Pipeline(stages = [gender_indexer, embark_indexer, \n",
    "                             gender_encoder, embark_encoder,\n",
    "                             assembler, log_reg])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = final_data.randomSplit([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_model = pipeline.fit(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = fit_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+--------+\n",
      "|prediction|Survived|\n",
      "+----------+--------+\n",
      "|       1.0|       0|\n",
      "|       1.0|       0|\n",
      "|       1.0|       0|\n",
      "+----------+--------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.select('prediction', 'Survived').show(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8512512512512513"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "eval = BinaryClassificationEvaluator(rawPredictionCol = 'rawPrediction', labelCol = 'Survived')\n",
    "AUC = eval.evaluate(results)\n",
    "AUC"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
